{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "# Analyze Errors and Explore Interpretability of Models\n",
    "\n",
    "This notebook demonstrates how to use the Responsible AI Widget's Error Analysis dashboard to understand a model trained on the Census dataset. The goal of this sample notebook is to classify whether income is greater than or less than 50K with scikit-learn and to explore model errors and explanations:\n",
    "\n",
    "1. Train a LightGBM classification model using Scikit-learn\n",
    "2. Run Interpret-Community's `MimicExplainer` to generate global model explanations.\n",
    "3. Visualize model errors and global and local explanations with the Error Analysis visualization dashboard."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Install Required Packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip install --upgrade interpret-community\n",
    "# %pip install --upgrade raiwidgets"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Explain\n",
    "\n",
    "### Run model explainer at training time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn import svm\n",
    "import pandas as pd\n",
    "import zipfile\n",
    "from lightgbm import LGBMClassifier\n",
    "\n",
    "# Explainer Used: Mimic Explainer\n",
    "from interpret.ext.blackbox import MimicExplainer\n",
    "from interpret.ext.glassbox import LinearExplainableModel\n",
    "from interpret.ext.glassbox import LGBMExplainableModel\n",
    "\n",
    "# You can use one of the following four interpretable models as a global surrogate to the black box model\n",
    "#from interpret.ext.glassbox import LGBMExplainableModel\n",
    "#from interpret.ext.glassbox import LinearExplainableModel\n",
    "#from interpret.ext.glassbox import SGDExplainableModel\n",
    "#from interpret.ext.glassbox import DecisionTreeExplainableModel\n",
    "\n",
    "# OR\n",
    "\n",
    "# PFI Explainer\n",
    "#from interpret.ext.blackbox import PFIExplainer "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load the UCI adult census income dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from raiutils.dataset import fetch_dataset\n",
    "outdirname = 'erroranalysis.12.3.20'\n",
    "zipfilename = outdirname + '.zip'\n",
    "\n",
    "fetch_dataset('https://publictestdatasets.blob.core.windows.net/data/' + zipfilename, zipfilename)\n",
    "\n",
    "with zipfile.ZipFile(zipfilename, 'r') as unzip:\n",
    "    unzip.extractall('.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data = pd.read_csv('adult-train.csv', skipinitialspace=True)\n",
    "test_data = pd.read_csv('adult-test.csv', skipinitialspace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "test_data_full = test_data\n",
    "test_data, _ = train_test_split(test_data, test_size=0.9, random_state=7)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from packaging import version\n",
    "import sklearn\n",
    "from sklearn.pipeline import Pipeline\n",
    "from sklearn.impute import SimpleImputer\n",
    "from sklearn.preprocessing import StandardScaler, OneHotEncoder\n",
    "from sklearn.compose import ColumnTransformer\n",
    "\n",
    "# for older scikit-learn versions use sparse, for newer sparse_output:\n",
    "if version.parse(sklearn.__version__) < version.parse('1.2'):\n",
    "    ohe_params = {\"sparse\": False}\n",
    "else:\n",
    "    ohe_params = {\"sparse_output\": False}\n",
    "\n",
    "def split_label(dataset):\n",
    "    X = dataset.drop(['income'], axis=1)\n",
    "    y = dataset[['income']]\n",
    "    return X, y\n",
    "\n",
    "def clean_data(X, y):\n",
    "    features = X.columns.values.tolist()\n",
    "    classes = y['income'].unique().tolist()\n",
    "    pipe_cfg = {\n",
    "        'num_cols': X.dtypes[X.dtypes == 'int64'].index.values.tolist(),\n",
    "        'cat_cols': X.dtypes[X.dtypes == 'object'].index.values.tolist(),\n",
    "    }\n",
    "    num_pipe = Pipeline([\n",
    "        ('num_imputer', SimpleImputer(strategy='median')),\n",
    "        ('num_scaler', StandardScaler())\n",
    "    ])\n",
    "    cat_pipe = Pipeline([\n",
    "        ('cat_imputer', SimpleImputer(strategy='constant', fill_value='?')),\n",
    "        ('cat_encoder', OneHotEncoder(handle_unknown='ignore', **ohe_params))\n",
    "    ])\n",
    "    feat_pipe = ColumnTransformer([\n",
    "        ('num_pipe', num_pipe, pipe_cfg['num_cols']),\n",
    "        ('cat_pipe', cat_pipe, pipe_cfg['cat_cols'])\n",
    "    ])\n",
    "    X = feat_pipe.fit_transform(X)\n",
    "    return X, feat_pipe, features, classes\n",
    "\n",
    "X_train_original, y_train = split_label(train_data)\n",
    "X_train, feat_pipe, features, classes = clean_data(X_train_original, y_train)\n",
    "X_test_original, y_test = split_label(test_data)\n",
    "X_test_original_full, y_test_full = split_label(test_data_full)\n",
    "y_test = y_test['income'].to_numpy()\n",
    "y_test_full = y_test_full['income'].to_numpy()\n",
    "X_test = feat_pipe.transform(X_test_original)\n",
    "features = train_data.columns.values[1:].tolist()\n",
    "classes = y_train['income'].unique().tolist()\n",
    "categorical_features = train_data.dtypes[train_data.dtypes == 'object'].index.values[1:].tolist()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train a LightGBM classification model, which you want to analyze"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "clf = LGBMClassifier(n_estimators=1)\n",
    "model = clf.fit(X_train, y_train['income'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Explain predictions on your local machine"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from interpret_community.common.constants import ShapValuesOutput, ModelTask\n",
    "# Using MimicExplainer with an LGBMExplainableModel surrogate\n",
    "model_task = ModelTask.Classification\n",
    "explainer = MimicExplainer(model, X_train_original, LGBMExplainableModel,\n",
    "                           augment_data=True, max_num_of_augmentations=10,\n",
    "                           features=features, classes=classes, model_task=model_task,\n",
    "                           transformations=feat_pipe)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generate global explanations\n",
    "Explain overall model predictions (global explanation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Passing in the test dataset as evaluation examples - note it must be a representative sample of the original data.\n",
    "# X_train_original can be passed as well; with more examples, explanations take longer but may be more accurate.\n",
    "global_explanation = explainer.explain_global(X_test_original)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sorted global feature importance values\n",
    "print('ranked global importance values: {}'.format(global_explanation.get_ranked_global_values()))\n",
    "# Corresponding feature names\n",
    "print('ranked global importance names: {}'.format(global_explanation.get_ranked_global_names()))\n",
    "# Feature ranks (based on original order of features)\n",
    "print('global importance rank: {}'.format(global_explanation.global_importance_rank))\n",
    "\n",
    "# Note: Do not run this cell if using PFIExplainer, it does not support per class explanations\n",
    "# Per class feature names\n",
    "print('ranked per class feature names: {}'.format(global_explanation.get_ranked_per_class_names()))\n",
    "# Per class feature importance values\n",
    "print('ranked per class feature values: {}'.format(global_explanation.get_ranked_per_class_values()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print out a dictionary that holds the sorted feature importance names and values\n",
    "print('global importance rank: {}'.format(global_explanation.get_feature_importance_dict()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.pipeline import Pipeline\n",
    "dashboard_pipeline = Pipeline(steps=[('preprocess', feat_pipe), ('model', model)])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Analyze\n",
    "### Analyze model errors and explanations using Error Analysis dashboard"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from raiwidgets import ErrorAnalysisDashboard"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Run error analysis on the full test dataset, with explanations computed on a 10% subsample of it.\n",
    "# Note that in this case we need to provide the true_y_dataset parameter, whose labels match the\n",
    "# original full dataset.\n",
    "ErrorAnalysisDashboard(global_explanation, dashboard_pipeline, dataset=X_test_original_full,\n",
    "                       true_y=y_test, categorical_features=categorical_features,\n",
    "                       true_y_dataset=y_test_full)"
   ]
  }
 ],
 "metadata": {
  "authors": [
   {
    "name": "mesameki"
   }
  ],
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
